import os

import numpy as np
import pytest
import torch
from pytest import approx, raises

import ignite.distributed as idist
from ignite.engine import Engine
from ignite.exceptions import NotComputableError
from ignite.metrics.regression import MeanAbsoluteRelativeError


def test_wrong_input_shapes():
    """Check that ``update`` rejects a prediction/target pair whose shapes differ."""
    metric = MeanAbsoluteRelativeError()

    # Exercise the mismatch in both directions: (4,) vs (4, 1), then (4, 1) vs (4,).
    mismatched_shapes = [((4,), (4, 1)), ((4, 1), (4,))]
    for first_shape, second_shape in mismatched_shapes:
        with raises(ValueError, match=r"Input data shapes should be the same, but given"):
            metric.update((torch.rand(first_shape), torch.rand(second_shape)))


def test_mean_absolute_relative_error(available_device):
    """Check the running metric value after each of four successive updates.

    Four random prediction batches are compared against one shared target.
    After every ``update`` the value returned by ``compute`` must equal the
    cumulative sum of ``|y - y_pred| / |y|`` divided by the number of samples
    seen so far.

    Args:
        available_device: value passed as the metric's ``device`` argument.
    """
    # Same draw order as before: the four prediction batches, then the target.
    predictions = [torch.rand(4) for _ in range(4)]
    # torch.rand samples from [0, 1) and can return an exact 0, which the
    # metric rejects with NotComputableError; keep the target strictly positive.
    ground_truth = torch.rand(4).clamp(min=1e-3)

    m = MeanAbsoluteRelativeError(device=available_device)
    assert m._device == torch.device(available_device)

    sum_error = torch.tensor(0.0)
    sum_samples = 0
    for y_pred in predictions:
        m.update((y_pred, ground_truth))

        batch_error = torch.sum(torch.abs(ground_truth - y_pred) / torch.abs(ground_truth))
        # Out-of-place add so the running total never aliases a per-batch tensor.
        sum_error = sum_error + batch_error
        sum_samples += y_pred.size()[0]

        expected = sum_error / sum_samples
        assert m.compute() == approx(expected.item())


def test_zero_div():
    """Check that a target containing 0 makes ``update`` raise ``NotComputableError``."""
    y_pred = torch.tensor([2.0, -1.0, -1.0, 2.0])
    # The first target entry is exactly zero.
    y_with_zero = torch.tensor([0.0, 0.5, 0.2, 1.0])

    metric = MeanAbsoluteRelativeError()
    with raises(NotComputableError, match=r"The ground truth has 0"):
        metric.update((y_pred, y_with_zero))


def test_zero_sample():
    """Check that ``compute`` raises ``NotComputableError`` when no update was made."""
    expected_message = r"MeanAbsoluteRelativeError must have at least one sample"
    metric = MeanAbsoluteRelativeError()
    with raises(NotComputableError, match=expected_message):
        metric.compute()


@pytest.mark.parametrize("n_times", range(5))
@pytest.mark.parametrize(
    "test_cases",
    [
        (torch.rand(size=(100,)), torch.rand(size=(100,)), 10),
        (torch.rand(size=(100, 1)), torch.rand(size=(100, 1)), 20),
    ],
)
def test_integration_mean_absolute_relative_error(n_times, test_cases, available_device):
    """Run the metric attached to an ``Engine`` and compare with a direct computation.

    The engine is fed iteration indices; ``update_fn`` turns each iteration into
    the matching ``batch_size`` slice of predictions and targets. The metric
    value stored under ``"mare"`` must match the mean of ``|y_pred - y| / |y|``
    over the whole data.

    Args:
        n_times: not used in the body; parametrized over ``range(5)`` so each
            case runs five times.
        test_cases: tuple ``(y_pred, y, batch_size)``.
        available_device: value passed as the metric's ``device`` argument.
    """
    y_pred, y, batch_size = test_cases
    # torch.rand samples from [0, 1) and can return an exact 0, which the metric
    # rejects with NotComputableError; keep the target strictly positive.
    y = y.clamp(min=1e-3)

    def update_fn(engine, batch):
        # engine.state.iteration is 1-based here, hence the "- 1".
        idx = (engine.state.iteration - 1) * batch_size
        y_true_batch = y[idx : idx + batch_size]
        y_pred_batch = y_pred[idx : idx + batch_size]
        return y_pred_batch, y_true_batch

    engine = Engine(update_fn)

    m = MeanAbsoluteRelativeError(device=available_device)
    assert m._device == torch.device(available_device)
    m.attach(engine, "mare")

    data = list(range(y_pred.shape[0] // batch_size))
    mare = engine.run(data, max_epochs=1).metrics["mare"]

    # The target is strictly positive, so no epsilon is needed in the denominator.
    res = torch.abs(y_pred - y) / torch.abs(y)
    expected = res.mean().item()

    assert mare == pytest.approx(expected)


def _test_distrib_compute(device):
    """Compare the metric's ``compute`` result with a NumPy reference on gathered data.

    Each call of the inner check updates the metric with this process's random
    integer-valued data (values drawn from 1 to 10), gathers predictions and
    targets with ``idist.all_gather`` and recomputes the expected value from the
    gathered arrays.

    Args:
        device: device on which the random input tensors are created.
    """
    rank = idist.get_rank()

    def _check(metric_device):
        metric = MeanAbsoluteRelativeError(device=torch.device(metric_device))

        local_pred = torch.randint(1, 11, size=(10,), device=device).float()
        local_true = torch.randint(1, 11, size=(10,), device=device).float()
        metric.update((local_pred, local_true))

        # Collect the inputs of all processes to build the reference value.
        all_pred = idist.all_gather(local_pred)
        all_true = idist.all_gather(local_true)
        np_pred = all_pred.cpu().numpy()
        np_true = all_true.cpu().numpy()

        result = metric.compute()

        total_error = np.sum(np.abs(np_true - np_pred) / np.abs(np_true))
        expected = total_error / len(all_pred)

        assert expected == approx(result)

    for step in range(3):
        # Seed depends on both the rank and the repetition index.
        torch.manual_seed(10 + rank + step)
        _check("cpu")
        if device.type != "xla":
            _check(idist.device())


def _test_distrib_integration(device):
    """Run the metric attached to an ``Engine`` and check it against NumPy.

    For each metric device and seed, an engine is run for one and for two
    epochs over 80 batches of 16 samples. The value stored under ``"mare"`` is
    compared with a reference computed from the predictions and targets
    gathered with ``idist.all_gather``.

    Args:
        device: device the random input tensors are moved to.
    """
    rank = idist.get_rank()

    def _test(n_epochs, metric_device):
        metric_device = torch.device(metric_device)
        n_iters = 80
        batch_size = 16

        # torch.rand samples from [0, 1) and can return an exact 0, which the
        # metric rejects with NotComputableError; keep targets strictly positive.
        y_true = torch.rand(size=(n_iters * batch_size,)).clamp(min=1e-3).to(device)
        y_preds = torch.rand(size=(n_iters * batch_size,)).to(device)

        def update(engine, i):
            batch = slice(i * batch_size, (i + 1) * batch_size)
            return y_preds[batch], y_true[batch]

        engine = Engine(update)

        m = MeanAbsoluteRelativeError(device=metric_device)
        m.attach(engine, "mare")

        data = list(range(n_iters))
        engine.run(data=data, max_epochs=n_epochs)

        # Gather into new names so the tensors captured by ``update`` are not rebound.
        all_preds = idist.all_gather(y_preds)
        all_true = idist.all_gather(y_true)

        assert "mare" in engine.state.metrics

        mare = engine.state.metrics["mare"]

        np_y_true = all_true.cpu().numpy()
        np_y_preds = all_preds.cpu().numpy()

        abs_error = np.sum(np.abs(np_y_true - np_y_preds) / np.abs(np_y_true))
        num_samples = len(all_preds)
        np_res = abs_error / num_samples

        assert approx(mare) == np_res

    metric_devices = ["cpu"]
    if device.type != "xla":
        metric_devices.append(idist.device())
    for metric_device in metric_devices:
        for i in range(2):
            # Seed depends on both the rank and the repetition index.
            torch.manual_seed(12 + rank + i)
            _test(n_epochs=1, metric_device=metric_device)
            _test(n_epochs=2, metric_device=metric_device)


@pytest.mark.distributed
@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU")
def test_distrib_nccl_gpu(distributed_context_single_node_nccl):
    """Run both distributed checks with the ``distributed_context_single_node_nccl`` fixture."""
    current_device = idist.device()
    for check in (_test_distrib_compute, _test_distrib_integration):
        check(current_device)


@pytest.mark.distributed
@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
def test_distrib_gloo_cpu_or_gpu(distributed_context_single_node_gloo):
    """Run both distributed checks with the ``distributed_context_single_node_gloo`` fixture."""
    current_device = idist.device()
    for check in (_test_distrib_compute, _test_distrib_integration):
        check(current_device)


@pytest.mark.distributed
@pytest.mark.skipif(not idist.has_hvd_support, reason="Skip if no Horovod dist support")
@pytest.mark.skipif("WORLD_SIZE" in os.environ, reason="Skip if launched as multiproc")
def test_distrib_hvd(gloo_hvd_executor):
    """Run both distributed checks through the ``gloo_hvd_executor`` fixture.

    With CUDA available the checks run on ``"cuda"`` with one process per
    visible GPU; otherwise they run on ``"cpu"`` with 4 processes.
    """
    has_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if has_cuda else "cpu")
    nproc = torch.cuda.device_count() if has_cuda else 4

    for check in (_test_distrib_compute, _test_distrib_integration):
        gloo_hvd_executor(check, (device,), np=nproc, do_init=True)


@pytest.mark.multinode_distributed
@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
@pytest.mark.skipif("MULTINODE_DISTRIB" not in os.environ, reason="Skip if not multi-node distributed")
def test_multinode_distrib_gloo_cpu_or_gpu(distributed_context_multi_node_gloo):
    """Run both distributed checks with the ``distributed_context_multi_node_gloo`` fixture."""
    current_device = idist.device()
    for check in (_test_distrib_compute, _test_distrib_integration):
        check(current_device)


@pytest.mark.multinode_distributed
@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
@pytest.mark.skipif("GPU_MULTINODE_DISTRIB" not in os.environ, reason="Skip if not multi-node distributed")
def test_multinode_distrib_nccl_gpu(distributed_context_multi_node_nccl):
    """Run both distributed checks with the ``distributed_context_multi_node_nccl`` fixture."""
    current_device = idist.device()
    for check in (_test_distrib_compute, _test_distrib_integration):
        check(current_device)


@pytest.mark.tpu
@pytest.mark.skipif("NUM_TPU_WORKERS" in os.environ, reason="Skip if NUM_TPU_WORKERS is in env vars")
@pytest.mark.skipif(not idist.has_xla_support, reason="Skip if no PyTorch XLA package")
def test_distrib_single_device_xla():
    """Run both distributed checks on the device reported by ``idist.device()``."""
    current_device = idist.device()
    for check in (_test_distrib_compute, _test_distrib_integration):
        check(current_device)


def _test_distrib_xla_nprocs(index):
    """Run both distributed checks on the device reported by ``idist.device()``.

    Args:
        index: not used in the body; presumably the process index supplied by
            the spawning executor (confirm against the executor fixture).
    """
    current_device = idist.device()
    for check in (_test_distrib_compute, _test_distrib_integration):
        check(current_device)


@pytest.mark.tpu
@pytest.mark.skipif("NUM_TPU_WORKERS" not in os.environ, reason="Skip if no NUM_TPU_WORKERS in env vars")
@pytest.mark.skipif(not idist.has_xla_support, reason="Skip if no PyTorch XLA package")
def test_distrib_xla_nprocs(xmp_executor):
    """Launch ``_test_distrib_xla_nprocs`` through ``xmp_executor``.

    The number of processes is read from the ``NUM_TPU_WORKERS`` environment variable.
    """
    num_workers = int(os.environ["NUM_TPU_WORKERS"])
    xmp_executor(_test_distrib_xla_nprocs, args=(), nprocs=num_workers)
